# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmselfsup.registry import MODELS

import numpy as np
import time
@MODELS.register_module()
class BEiTV2Head(BaseModule):
    """Pretraining head for BEiT v2.

    One linear classifier is applied to the masked positions of two feature
    tensors; both sets of logits are then passed, together with the masked
    targets, to the loss built from ``loss``.

    Args:
        embed_dims (int): The dimension of embedding (input size of the
            linear classifier).
        num_embed (int): The number of classification types (output size of
            the linear classifier).
        loss (dict): The config of loss, built through ``MODELS``.
        init_cfg (dict or List[dict], optional): Initialization config dict
            forwarded to ``BaseModule``. Defaults to
            ``dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)``.
    """

    def __init__(
        self,
        embed_dims: int,
        num_embed: int,
        loss: dict,
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0),
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        # Keep this creation order: the classifier is registered before the
        # loss module.
        self.cls_head = nn.Linear(embed_dims, num_embed)
        self.loss = MODELS.build(loss)

    def forward(self, name, feats: torch.Tensor, feats_cls_pt: torch.Tensor,
                target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Generate loss.

        Args:
            name: Not read by this method; it stays in the signature so that
                existing callers keep working.
                NOTE(review): looks like a sample/file name - confirm
                against the caller.
            feats (torch.Tensor): Features from backbone.
                Presumably (B, N, embed_dims) - TODO confirm.
            feats_cls_pt (torch.Tensor): Features from class late layers for
                pretraining. Indexed with the same mask as ``feats``.
            target (torch.Tensor): Target generated by target_generator.
                Indexed with the same mask, so its leading dimensions are
                assumed to match the flattened mask - TODO confirm.
            mask (torch.Tensor): Generated mask for pretraining. It is
                flattened from dimension 1 onwards and cast to ``bool``
                before being used as an index.

        Returns:
            torch.Tensor: Whatever the configured loss module returns for
            the pair of logits and the masked target.
        """
        selector = mask.flatten(1).to(torch.bool)
        masked_target = target[selector]

        # The same classifier serves both feature branches; the backbone
        # branch is evaluated first, then the cls_pt branch.
        logits_pair = tuple(
            self.cls_head(branch[selector])
            for branch in (feats, feats_cls_pt))

        return self.loss(logits_pair, masked_target)
